#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
Prepare dfr-browser data files from MALLET outputs.
An alternative to this script with more features can be found in the
export_browser_data function in this R package:
http://github.com/agoldst/dfrtopics.

The original script has been updated slightly for greater Python 3 compatibility.

Usage:
    prepare-data check [<dir>]
        Check presence and format of data files in <dir> or PWD

    prepare-data info-stub [-o <file>]
        Write stub info.json to <file> or "info.json"

    prepare-data convert-citations ... [-o <file>] [--ids <ids>]
        Convert JSTOR DfR citations.tsv files ... into zipped metadata
        <file>: name of file to write ("meta.csv.zip" by default)
        <ids>: (optional) a file with one document ID per line. If
        supplied, these IDs are matched against the first column of the
        input metadata, and the output contains only matching documents
        (in the order given in <ids>).

    prepare-data convert-tw [-o <file>] <tw> --vocab <v> [--param <p>] [-n N]
        Write topic-word information to <file> or "tw.json"
        <tw>: CSV with topic-word matrix (headerless)
        <v>: vocabulary listing, one line per column of <tw>
        <p>: (optional) params.txt written by dfrtopics v0.2
            If this is missing, topic alpha values will be missing
        <n>: number of top words per topic (50 by default)
        Corrected error in def transform_topic_weights(weights, vocab, n)
        Without this correction only 1% of vocabulary taken into consideration
        Dan Baciu, Jan 17, 2019

    prepare-data convert-keys [-o <file>] <keys>
        Write topic-word information to <file> or "tw.json"
        <keys>: CSV with topic,alpha,word,weight columns (from dfrtopics v0.1)
            or with topic,word,weight columns (from dfrtopics v0.2)

    prepare-data convert-dt [-o <file>] <dt>
        Write zipped document-topic information to <file> or "dt.json.zip"
        <dt>: CSV with document-topic weights (headerless).
            This ordinary matrix is converted to the required sparse format.

    prepare-data convert-state <state> [--tw <tw>] [--dt <dt>]
        Use the MALLET sampling state to write both topic words and document
        topics.
        <state>: gzipped file from mallet train-topics --output-state or
        dfrtopics::write_mallet_state
        <tw>: name of topic-word file to output ("tw.json" by default)
        <dt>: name of doc-topic file to output ("dt.json.zip" by default)
"""

import gzip
import json
import os
import re
import zipfile as zf
from collections import defaultdict


def check_files(d):
    """Check that the dfr-browser data files in a directory exist and parse.

    Looks for info.json, meta.csv.zip, topic_scaled.csv, tw.json and
    dt.json.zip in `d` and prints a status report for each one. Problems with
    an individual file are printed rather than raised, so the remaining files
    are still checked.

    Args:
        d: path of the directory to inspect.

    Raises:
        OSError: if `d` itself cannot be listed.
    """
    fs = os.listdir(d)
    # Metadata lines. Stays empty when meta.csv.zip is absent or unreadable,
    # in which case the doc-topics/metadata consistency check is skipped.
    meta = []

    print("Checking info.json...")
    if "info.json" not in fs:
        print("""info.json not found.
Use "prepare-data info-stub" to create one.
""")
    else:
        with open(os.path.join(d, "info.json")) as f:
            try:
                json.load(f)
                print("info.json: ok")
            except Exception as e:
                # Python 3 exceptions have no .message attribute; use str().
                print("Unable to parse info.json as JSON. json error:\n" +
                        str(e))

    print("Checking meta.csv.zip...")
    if "meta.csv.zip" in fs:
        try:
            with zf.ZipFile(os.path.join(d, "meta.csv.zip")) as z:
                # Only the first archive member is read.
                with z.open(z.infolist()[0]) as f:
                    meta = f.readlines()
            print("meta.csv.zip: ok")
        except Exception as e:
            print("Problem with meta.csv.zip: " + str(e))
    else:
        print("""No meta.csv.zip found.
Use "prepare-data convert-citations" on DfR citations.tsv files.
""")

    print("Checking topic_scaled.csv...")
    if "topic_scaled.csv" not in fs:
        print("""No topic_scaled.csv found. This file is required only for the "scaled" overview.
"""
        )
    else:
        print("topic_scaled.csv: ok")

    print("Checking tw.json...")
    if "tw.json" not in fs:
        print("""No tw.json found.
Use "prepare-data convert-tw" on a topic-word matrix CSV or
"prepare-data convert-keys" on a CSV listing top words and weights.
""")
    else:
        with open(os.path.join(d, "tw.json")) as f:
            try:
                tw = json.load(f)
                if "alpha" not in tw:
                    raise Exception("alpha values not present")

                if "tw" not in tw:
                    raise Exception("tw field missing")

                # Topics are reported 1-based in the messages.
                for t, topic in enumerate(tw["tw"]):
                    if "words" not in topic:
                        raise Exception("words missing for topic " + str(t + 1))
                    if "weights" not in topic:
                        raise Exception("weights missing for topic " +
                                str(t + 1))
                print("tw.json: ok")
            except Exception as e:
                print("Problem with tw.json: " + str(e))

    print("Checking dt.json.zip...")
    if "dt.json.zip" not in fs:
        print("""No dt.json.zip found.
Use "prepare-data convert-dt" on a document-topic matrix CSV.
""")
    else:
        try:
            with zf.ZipFile(os.path.join(d, "dt.json.zip")) as z:
                with z.open(z.infolist()[0]) as f:
                    dt = json.load(f)
            if "i" not in dt or "p" not in dt or "x" not in dt:
                raise Exception("i, p, or x member missing")
            if len(meta) > 0:
                # The largest document index must be the last metadata row.
                if max(dt["i"]) != len(meta) - 1:
                    raise Exception("doc-topics / metadata mismatch")
            print("dt.json.zip: ok")
        except Exception as e:
            print("problem with doc-topics JSON: " + str(e))

def info_stub(out):
    """Write a skeleton info.json for the user to fill in.

    Args:
        out: path of the JSON file to create (overwritten if it exists).
    """
    stub = {
        "title": "",
        "meta_info": r'<h2></h2>',
        "VIS": {"overview_words": 15},
    }
    with open(out, "w") as handle:
        json.dump(stub, fp=handle, indent=4)
        print("Created stub file in " + handle.name)

def convert_citations(fs, matchfile, out):
    """Convert tab-separated citation files into a zipped, headerless CSV.

    Each input file's first line is skipped as a header. For every other
    non-blank line, column 0 (the document ID) and columns 2-8 are written as
    one fully quoted CSV row, with embedded double quotes doubled.

    Args:
        fs: list of paths to tab-separated citation files.
        matchfile: optional path to a file with one document ID per line. When
            given, only those documents are written, in that order. None
            writes every row in input order.
        out: path of the zip archive to write; it gets one member, "meta.csv".

    Raises:
        ValueError: if `matchfile` names an ID that is not in the input. The
            previous code silently substituted the first metadata row, which
            mislabels documents downstream.
    """
    rows = []
    # Maps document ID -> position in `rows`; for duplicate IDs the last wins.
    row_index = {}
    for path in fs:
        with open(path) as meta:
            meta.readline()  # header
            for line in meta:
                if not line.strip():
                    continue  # Ignore lines that contain only whitespace.
                fields = [s.replace('"', '""') for s in line.strip().split("\t")]
                row_index[fields[0]] = len(rows)
                rows.append('"' + fields[0] + '","' + '","'.join(fields[2:9]) +
                        '"')

    if matchfile is not None:
        with open(matchfile) as f:
            # Blank lines (e.g. a trailing newline) are not document IDs.
            out_ids = [s.strip() for s in f if s.strip()]
        missing = [key for key in out_ids if key not in row_index]
        if missing:
            raise ValueError(
                    "IDs not found in the citations metadata: " +
                    ", ".join(missing))
        meta_out = "\n".join(rows[row_index[key]] for key in out_ids)
    else:
        meta_out = "\n".join(rows)

    with zf.ZipFile(out, "w") as z:
        z.writestr("meta.csv", meta_out)

    print("Wrote metadata to " + out)

def convert_tw(twf, out, vocabf, paramf, n):
    """Convert a topic-word matrix CSV into the tw.json format.

    Args:
        twf: path to a headerless CSV with one row of integer weights per
            topic and one column per vocabulary entry.
        out: path of the JSON file to write.
        vocabf: path to the vocabulary file, one word per line, in the same
            order as the columns of `twf`.
        paramf: optional path to a parameters file containing a line of the
            form "alpha = c(...)". None sets every alpha to 0.
        n: number of top words to keep per topic.

    Raises:
        ValueError: if `paramf` is given but no alpha values can be found in
            it, or if a weight or alpha value cannot be parsed.
    """
    with open(vocabf) as f:
        vocab = [s.strip() for s in f]

    tw = []
    with open(twf) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the matrix file
            weights = [int(x) for x in line.split(",")]
            tw.append(transform_topic_weights(weights, vocab, n))

    if paramf is None:
        print("Warning: no parameters file, so alpha_k will be set to zero")
        alpha = [0] * len(tw)
    else:
        with open(paramf) as f:
            # The alpha vector may wrap over several lines, so join them.
            p = " ".join(l.strip() for l in f)
        # Also accept scientific notation such as 1e-04.
        m = re.search(r'alpha = c\(([0-9eE.,+\- ]+)\)', p)
        if m is None:
            raise ValueError(
                    "No 'alpha = c(...)' entry found in " + str(paramf))
        alpha = [float(a) for a in m.group(1).split(",") if a.strip()]

    write_tw(alpha, tw, out)

def transform_topic_weights(weights, vocab, n):
    """Pick the `n` heaviest entries of one topic's weight vector.

    Args:
        weights: indexable sequence of numeric weights.
        vocab: container indexed by the same positions as `weights`.
        n: number of top entries to keep.

    Returns:
        A dict with parallel lists "words" and "weights", sorted by
        descending weight; ties keep their original relative order.
    """
    # Rank positions rather than values so words and weights stay aligned.
    ranked = sorted(range(len(weights)), key=lambda pos: -weights[pos])
    top = ranked[:n]
    return {
        "words": [vocab[pos] for pos in top],
        "weights": [weights[pos] for pos in top],
    }

def write_tw(alpha, tw, out):
    """Write topic alphas and top words to a JSON file.

    Args:
        alpha: list with one alpha value per topic.
        tw: list with one {"words": [...], "weights": [...]} dict per topic.
        out: path of the JSON file to write.
    """
    twj = {
        "alpha": alpha,
        "tw": tw
    }
    # The leftover debugging dump of the whole structure to stdout has been
    # removed; only the confirmation line is printed.
    with open(out, "w") as f:
        json.dump(twj, f)

    print("Wrote topic-words information to " + f.name)

def convert_keys(keysf, out):
    """Convert a top-words CSV into the tw.json format.

    The number of comma-separated columns in the header line selects the
    parser: 4 columns use keys_oldstyle, 3 columns use keys_newstyle.

    Args:
        keysf: path to the top-words CSV (with a header line).
        out: path of the JSON file to write.

    Raises:
        Exception: if the header has neither 3 nor 4 columns.
    """
    with open(keysf) as handle:
        column_count = len(handle.readline().strip().split(","))
        # The chosen parser consumes the rest of the open file.
        if column_count == 4:
            alphas, tw = keys_oldstyle(handle)
        elif column_count == 3:
            print("New-style top topic words: alpha_k missing and left at 0")
            alphas, tw = keys_newstyle(handle)
        else:
            raise Exception("Unknown top topic words format: expect 3 or 4 cols")

    write_tw(alphas, tw, out)

def keys_oldstyle(f):
    """Parse topic,alpha,word,weight rows into alphas and top words.

    Consecutive rows with the same topic number form one topic. Each topic's
    alpha is taken from that topic's own rows. The previous code appended the
    alpha of the *next* topic's first row, which shifted every alpha by one
    topic and duplicated the last one.

    Args:
        f: iterable of text lines (header already consumed).

    Returns:
        A tuple (alphas, tw): `alphas` has one float per topic and `tw` one
        {"words": [...], "weights": [...]} dict per topic, in file order.
        Both are empty when `f` has no data rows.
    """
    tw = []
    alphas = []
    words = []
    weights = []
    # None until the first row is seen, so no topic numbering is assumed.
    cur_topic = None
    cur_alpha = None
    for line in f:
        line = line.strip()
        if not line:
            continue
        topic, alpha, word, weight = line.split(",")
        topic = int(topic)
        if cur_topic is not None and topic != cur_topic:
            # Topic changed: flush the finished topic with its own alpha.
            tw.append({ "words": words, "weights": weights })
            alphas.append(cur_alpha)
            words = []
            weights = []
        cur_topic = topic
        cur_alpha = float(alpha)
        words.append(word)
        weights.append(int(weight))

    if cur_topic is not None:
        tw.append({ "words": words, "weights": weights })
        alphas.append(cur_alpha)

    return(alphas, tw)

def keys_newstyle(f):
    """Parse topic,word,weight rows into top words, with alphas set to 0.

    Consecutive rows with the same topic number form one topic. The first
    topic is taken from the first row instead of being assumed to be 1, so
    other numbering (or empty input) no longer yields a spurious empty topic.

    Args:
        f: iterable of text lines (header already consumed).

    Returns:
        A tuple (alphas, tw): `alphas` is a list of zeros, one per topic, and
        `tw` has one {"words": [...], "weights": [...]} dict per topic, in
        file order. Both are empty when `f` has no data rows.
    """
    tw = []
    words = []
    weights = []
    cur_topic = None  # None until the first data row is seen
    for line in f:
        line = line.strip()
        if not line:
            continue
        topic, word, weight = line.split(",")
        topic = int(topic)
        if cur_topic is not None and topic != cur_topic:
            tw.append({ "words": words, "weights": weights })
            words = []
            weights = []
        cur_topic = topic
        words.append(word)
        weights.append(int(weight))

    if cur_topic is not None:
        tw.append({ "words": words, "weights": weights })

    # This format carries no alpha column, so alpha is left at zero.
    return([0] * len(tw), tw)



def convert_dt(dtf, out):
    """Convert a dense matrix CSV into the zipped sparse doc-topics file.

    Every line of the headerless CSV is one row of integer weights. The rows
    are regrouped column by column, converted with transform_dt and written
    with write_dt.

    Args:
        dtf: path to the headerless CSV.
        out: path of the zip archive to write.
    """
    with open(dtf) as handle:
        # The first row fixes the number of columns.
        first_cells = handle.readline().strip().split(",")
        columns = [[int(cell)] for cell in first_cells]
        for row in handle:
            cells = row.strip().split(",")
            for t, column in enumerate(columns):
                column.append(int(cells[t]))

    sparse = transform_dt(columns)
    write_dt(sparse, out)

def transform_dt(dt):
    """Convert a list of dense columns into a compressed sparse structure.

    Args:
        dt: non-empty list of columns, each indexable by row; the number of
            rows is taken from the first column.

    Returns:
        A dict with "i" (row index of each non-zero entry), "x" (the entry
        values, parallel to "i") and "p" (running count of non-zero entries
        after each column, starting with 0).
    """
    n_rows = len(dt[0])
    row_indices = []
    values = []
    col_pointers = [0]
    for column in dt:
        for row in range(n_rows):
            weight = column[row]
            if weight != 0:
                row_indices.append(row)
                values.append(weight)
        # Entries stored so far mark where the next column begins.
        col_pointers.append(len(row_indices))

    return {"i": row_indices, "p": col_pointers, "x": values}

def write_dt(dtj, out):
    """Write the sparse doc-topics structure as "dt.json" inside a zip file.

    Args:
        dtj: JSON-serializable object to store.
        out: path of the zip archive to create.
    """
    with zf.ZipFile(out, "w") as archive:
        payload = json.dumps(dtj)
        archive.writestr("dt.json", payload)

    message = "Wrote sparse doc-topics to " + out
    print(message)

def convert_state(state_file, tw_file, dt_file, n=50):
    """Build topic-word and doc-topic files from a gzipped sampling state.

    The first three lines of the state file are treated as a header: the
    first is skipped, and the second and third supply the alpha values and
    beta (everything after the first two space-separated fields). Every other
    line is one token, "doc source pos typeindex word topic".

    Fixes relative to the previous version:
      * top words are looked up in a word list parallel to each topic's
        weight list (weights were indexed by position, but words were looked
        up in a dict keyed by typeindex, yielding the wrong words);
      * topics are emitted in topic-number order, not first-seen order;
      * document rows are addressed by document index, so a document without
        tokens no longer shifts all later documents;
      * the topic count is at least the number of alpha values;
      * `n` is coerced to int (the command line passes it as a string).

    Args:
        state_file: path to the gzipped state file.
        tw_file: path of the topic-words JSON file to write.
        dt_file: path of the zipped doc-topics file to write.
        n: number of top words to keep per topic.

    Raises:
        ValueError: if a token line has fewer than six fields or a numeric
            field cannot be parsed.
    """
    n = int(n)

    # Read as bytes and split on ASCII whitespace, exactly as before, so that
    # words containing non-ASCII whitespace are not split apart.
    with gzip.open(state_file, "rb") as f:
        f.readline()  # first header line is not used
        alpha = [float(a)
                 for a in f.readline().decode().strip().split(" ")[2:]]
        beta = f.readline().decode().strip().split(" ")[2]
        print(f"beta value, not saved in a file: {beta}")

        # topic number -> {typeindex: token count}
        tw = defaultdict(lambda: defaultdict(int))
        # typeindex -> word
        vocab = {}
        # dt[doc] is {topic number: token count} for that document
        dt = []
        max_topic = -1

        for line in f:
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 6:
                raise ValueError("Malformed state line: " + repr(line))
            # Take the last three fields from the right-hand end so a source
            # name containing spaces does not break parsing.
            # NOTE(review): assumes words never contain ASCII whitespace.
            doc = int(fields[0])
            typeindex = int(fields[-3])
            word = fields[-2]
            topic = int(fields[-1])

            if topic > max_topic:
                max_topic = topic
            # Pad so that documents with no tokens still get an (empty) row.
            # Token-less documents after the last one seen cannot be detected.
            while len(dt) <= doc:
                dt.append(defaultdict(int))
            dt[doc][topic] += 1
            tw[topic][typeindex] += 1
            if typeindex not in vocab:
                vocab[typeindex] = word.decode()

    # A topic with no tokens assigned never appears in a token line, so the
    # alpha count is used as a lower bound on the number of topics.
    K = max(len(alpha), max_topic + 1)

    transformed_tw = []
    for t in range(K):
        counts = tw.get(t, {})
        type_ids = list(counts)
        # Parallel lists: position i in both refers to the same word type.
        topic_weights = [counts[ti] for ti in type_ids]
        topic_words = [vocab[ti] for ti in type_ids]
        transformed_tw.append(
            transform_topic_weights(topic_weights, topic_words, n))
    write_tw(alpha, transformed_tw, tw_file)

    transformed_dt = transform_dt(
        [[d.get(t, 0) for d in dt] for t in range(K)])
    write_dt(transformed_dt, dt_file)


def help():
    """Print the module docstring, which doubles as the usage message."""
    usage_text = __doc__
    print(usage_text)

def key_arg(args, key):
    """Extract the value of a "key value" option from an argument list.

    The list that is passed in is the one searched and modified. The previous
    code read and mutated the module-level `rest` instead, which only worked
    because every caller happened to pass that same list.

    Args:
        args: mutable list of command-line arguments. When the option is
            found with a value, both the key and the value are removed.
        key: option name to look for, e.g. "-o".

    Returns:
        The argument following `key`, or None if `key` is absent or is the
        last argument (in which case `args` is left unchanged).
    """
    try:
        i = args.index(key)
    except ValueError:
        return None

    if i + 1 >= len(args):
        # Option given without a value: treat it as not supplied.
        return None

    res = args[i + 1]
    del args[i:i + 2]
    return res

if __name__ == "__main__":
    # Command-line entry point: the first argument selects the sub-command and
    # the remaining arguments are that command's options and inputs. Unknown
    # commands, "help", and a bare invocation print the usage text.
    import sys
    if len(sys.argv) == 1:
        help()
    else:
        cmd = sys.argv[1]

        # `rest` holds the remaining arguments; key_arg() removes each option
        # it consumes, so what is left afterwards are the positional inputs.
        rest = sys.argv[2:]
        out = key_arg(rest, "-o")

        if cmd == "help":
            help()
        elif cmd == "check":
            if len(rest) > 0:
                check_files(rest[0])
            else:
                check_files(".")
        elif cmd == "info-stub":
            if out is None:
                info_stub("info.json")
            else:
                info_stub(out)
        elif cmd == "convert-citations":
            if out is None:
                out = "meta.csv.zip"
            matchfile = key_arg(rest, "--ids")
            # Everything left in `rest` is an input citations file.
            convert_citations(rest, matchfile, out)
        elif cmd == "convert-tw":
            if out is None:
                out = "tw.json"
            vocabf = key_arg(rest, "--vocab")
            if vocabf is None:
                raise Exception(
                        "A vocabulary file must be supplied with --vocab")

            paramf = key_arg(rest, "--param")

            n = key_arg(rest, "-n")
            if n is None:
                n = 50
            convert_tw(rest[0], out, vocabf, paramf, int(n))
        elif cmd == "convert-keys":
            if out is None:
                out = "tw.json"
            convert_keys(rest[0], out)
        elif cmd == "convert-dt":
            if out is None:
                out = "dt.json.zip"
            convert_dt(rest[0], out)
        elif cmd == "convert-state":
            tw = key_arg(rest, "--tw")
            if tw is None:
                tw = "tw.json"
            dt = key_arg(rest, "--dt")
            if dt is None:
                dt = "dt.json.zip"
            n = key_arg(rest, "-n")
            if n is None:
                n = 50
            # Options arrive as strings; convert as convert-tw does so the
            # value can be used as a slice bound.
            convert_state(rest[0], tw, dt, int(n))
        else:
            help()
